from enum import Enum
import json
import xmltodict
from typing import *
from Grammar import *
from pathlib import Path
from Parser import parse, ParsingException


# Attribute keys excluded from grading. extract_string_and_numeric_pairs drops
# any key containing one of these as a substring; Evaluate never penalizes them.
ignore = [
    "@xmlns",
    "@priority",
    "@id",
    "xm:sqref",
    "@indexed",
    "@rank",
    "@patternType",
]
# Keys whose values, when missing from the script text, make Evaluate report
# PartialExecution instead of Function.
optional_keys = [
    "@rgb",
    "@type",
    "xm:f",
]


# Maps str() of a predicate (e.g. "Predicate.EqualTo") to the rule type stored
# in the revision record's "@type" attribute. Predicates not listed here fall
# back to their bare member name in RevisionTypeConverter.
# NOTE(review): assumes Grammar's predicate enum has members named
# GreaterThanOrEqual / LessThanOrEqual — confirm. The original table listed
# "Predicate.LessThan" twice and misspelled GreaterThanOrEqual, so neither
# "or equal" comparison could ever map to "cellIs".
RevisionTypeMapping = {
    "Predicate.TopBottom": "top10",
    "Predicate.Contains": "containsText",
    "Predicate.EqualTo": "cellIs",
    "Predicate.GreaterThan": "cellIs",
    "Predicate.GreaterThanOrEqual": "cellIs",
    "Predicate.LessThan": "cellIs",
    "Predicate.LessThanOrEqual": "cellIs",
    # Misspelled key from the original table, kept in case the enum member
    # really is spelled this way.
    "Predicate.GreaterThanIrEqual": "cellIs",
}


def RevisionTypeConverter(function: str):
    """Translate a predicate name into the rule type used in revision records.

    Args:
        function: ``str()`` of a predicate, e.g. ``"Predicate.EqualTo"``.

    Returns:
        The value from ``RevisionTypeMapping`` when ``function`` is a key there.
        Otherwise, the second ``.``-separated segment (``"Predicate.Foo"`` ->
        ``"Foo"``), or ``function`` unchanged if it contains no ``.``.
    """
    mapped = RevisionTypeMapping.get(function)
    if mapped is not None:
        return mapped
    parts = function.split(".")
    # A dotless name previously raised IndexError; return it as-is instead.
    return parts[1] if len(parts) > 1 else function


class MatchType(Enum):
    """Grade assigned by Evaluate to one generated script.

    test_instruct_excel reports Execution, Execution+PartialExecution and
    Execution+PartialExecution+Function as cumulative success rates.
    """

    # Predicate matched and no checked attribute value was missing.
    Execution = 1
    # Predicate matched; an optional attribute value was missing.
    PartialExecution = 2
    # Predicate matched; a required attribute value was missing.
    Function = 3
    # Predicate did not match the recorded rule type.
    Fail = 4
    # The script could not be parsed.
    Error = 5


def CheckFunctionMatch(pred: OfficeJSPredicate, RevisionRecords: Dict[str, Any]):
    """Return True if the script's predicate matches a recorded rule type.

    Args:
        pred: Parsed script; its ``pred`` attribute is the predicate.
        RevisionRecords: Attribute name -> recorded value(s). Only ``"@type"``
            is consulted. It may hold a single string or an iterable of strings.

    Returns:
        True when the converted name of ``pred.pred``, or of any predicate in
        an ``EquivalentSet`` group containing it, is a case-insensitive
        substring of one of the recorded ``@type`` values. Otherwise False,
        including when ``@type`` is absent. On a mismatch the recorded types
        and the predicate are printed.
    """
    if "@type" not in RevisionRecords:
        return False
    function = RevisionRecords["@type"]
    # A bare string would otherwise be iterated one character at a time and
    # could never match.
    recorded = [function] if isinstance(function, str) else function

    # Lower-cased rule-type names acceptable for this predicate: its own name
    # plus the names of every predicate grouped with it in EquivalentSet.
    candidates = [RevisionTypeConverter(str(pred.pred)).lower()]
    for eq_set in EquivalentSet:
        if pred.pred in eq_set:
            candidates.extend(RevisionTypeConverter(str(elem)).lower() for elem in eq_set)

    for func in recorded:
        func_lower = func.lower()
        if any(candidate in func_lower for candidate in candidates):
            return True
    print(f"{function} -- {pred.pred}")
    return False


def Evaluate(officeJS: str, RevisionRecords: Dict[str, Any]) -> MatchType:
    """Grade a generated Office Script against recorded revision attributes.

    Args:
        officeJS: Source text of the generated script.
        RevisionRecords: Attribute name -> recorded value(s). ``"@type"`` holds
            the rule type. A single value may be given as a bare scalar
            instead of a list. The mapping is not modified.

    Returns:
        - ``MatchType.Error`` if the script cannot be parsed.
        - ``MatchType.Fail`` if its predicate does not match ``@type``.
        - ``MatchType.Function`` if any required value is absent from the
          script text.
        - ``MatchType.PartialExecution`` if only optional values are absent.
        - ``MatchType.Execution`` otherwise.
    """
    try:
        parsed_office_script = parse(officeJS)
    except ParsingException:
        return MatchType.Error

    if not CheckFunctionMatch(parsed_office_script, RevisionRecords):
        return MatchType.Fail

    missing_optional = False
    for key, value in RevisionRecords.items():
        # "@type" was checked above; "@rgb" and ignored keys never count.
        # (The original deleted "@type" from the caller's dict instead.)
        if key == "@type" or key == "@rgb" or key in ignore:
            continue
        values = [value] if isinstance(value, (str, int, float)) else value
        for val in values:
            # Values are matched as plain substrings of the script text;
            # str() keeps numeric values from raising TypeError.
            if str(val) in officeJS:
                continue
            if key in optional_keys:
                missing_optional = True
            else:
                # A missing required value outranks any missing optional one,
                # whatever order the keys appear in.
                return MatchType.Function
    return MatchType.PartialExecution if missing_optional else MatchType.Execution


def extract_string_and_numeric_pairs(d, ignore_keys=None):
    """Flatten nested dicts/lists into ``(key, scalar)`` pairs.

    Recurses depth-first in iteration order. Each dict entry whose key and
    value are both ``str``/``int``/``float`` becomes a ``(key, value)`` pair,
    unless some entry of ``ignore_keys`` is a substring of the key. Any other
    value is descended into. Values that are neither dicts nor lists (e.g.
    ``None``) contribute nothing.

    Args:
        d: A dict, list, or any other object (which yields no pairs).
        ignore_keys: Substrings marking keys to drop. Defaults to the
            module-level ``ignore`` list.

    Returns:
        A list of ``(key, value)`` tuples.
    """
    if ignore_keys is None:
        ignore_keys = ignore
    scalar_types = (str, int, float)
    result = []

    if isinstance(d, dict):
        for key, value in d.items():
            if isinstance(key, scalar_types) and isinstance(value, scalar_types):
                # str() so numeric keys do not raise TypeError in the
                # substring test.
                key_text = str(key)
                if any(k in key_text for k in ignore_keys):
                    continue
                result.append((key, value))
            else:
                result.extend(extract_string_and_numeric_pairs(value, ignore_keys))
    elif isinstance(d, list):
        for e in d:
            result.extend(extract_string_and_numeric_pairs(e, ignore_keys))

    return result


def parse_revision_record(revision_record: Dict[str, Any]):
    """Collect attribute pairs from every revision in a revision record.

    Args:
        revision_record: Nested dict with a ``["revStream"]["xrr"]`` entry
            holding either one revision or a list of revisions.

    Returns:
        The concatenated ``extract_string_and_numeric_pairs`` output of each
        revision's ``["co"]["objectState"]``. Revisions without that path
        (or that are not dicts) are skipped.

    Raises:
        KeyError: if ``revStream`` or ``xrr`` is missing.
    """
    extracted_pairs = []
    revisions = revision_record["revStream"]["xrr"]
    if not isinstance(revisions, list):
        revisions = [revisions]
    for revision in revisions:
        # Only the lookup is guarded. The previous bare except also hid bugs
        # in extraction and swallowed KeyboardInterrupt.
        try:
            object_state = revision["co"]["objectState"]
        except (KeyError, TypeError):
            continue
        extracted_pairs.extend(extract_string_and_numeric_pairs(object_state))
    return extracted_pairs


def test(test_dir):
    """Evaluate each ``<name>.js`` in ``test_dir`` against ``<name>.rr``.

    Prints one line per script that has a matching ``.rr`` file. On the first
    script graded ``MatchType.Error``, prints a failure line and exits with
    status 1. Otherwise prints "All Tests Passed!".

    Args:
        test_dir: Directory holding the ``.js`` / ``.rr`` file pairs.
    """
    directory = Path(test_dir)
    officejs_scripts = {p.stem: p.read_text() for p in directory.glob("*.js")}
    revision_records = {p.stem: p.read_text() for p in directory.glob("*.rr")}

    for name, officejs_script in officejs_scripts.items():
        if name not in revision_records:
            continue
        # parse_revision_record indexes its argument as a dict, so the raw
        # text must be parsed first.
        # NOTE(review): assumes .rr files hold the raw revision-log XML (the
        # '@'-prefixed keys used elsewhere follow xmltodict's attribute
        # convention) — confirm.
        revision_record = xmltodict.parse(revision_records[name])

        # Evaluate expects attribute -> list of values, the same shape that
        # test_instruct_excel builds.
        pairs_by_key = {}
        for key, value in parse_revision_record(revision_record):
            pairs_by_key.setdefault(key, []).append(value)

        evaluation = Evaluate(officejs_script, pairs_by_key)
        if evaluation is not MatchType.Error:
            print(f"Test Passed For {name} With {evaluation}")
        else:
            print(f"Test Failed for {name}")
            # The original exit(0) reported success to the shell on failure.
            raise SystemExit(1)
    print("All Tests Passed!")


def test_instruct_excel(
    query_file: str = "../../../Datasets/OfficeScripts/InstructExcelCF.json",
    revrecords_file: str = "revrecords.json",
    outputs_file: str = "../../../Results/IE_CF_test/RetDocOnly_rar_extop10.json",
    failed_indices_file: str = "test.json",
):
    """Score generated Office Scripts against their recorded revision records.

    Args:
        query_file: Path to the query set. It is not read (its contents were
            never used); the parameter is kept for existing callers.
        revrecords_file: JSON list of objects with ``index`` and
            ``revrecords`` fields.
        outputs_file: JSON list of objects with ``index`` and ``code`` fields.
        failed_indices_file: Where the indices graded Fail or Error are
            written as JSON.

    Prints the number of evaluated outputs and the cumulative execution,
    partial-execution and function match rates.
    """
    with open(revrecords_file, "r", encoding="utf-8-sig") as f:
        revrecords = json.load(f)
    with open(outputs_file, "r", encoding="utf-8-sig") as f:
        outputs = json.load(f)

    # Index the records once instead of scanning the whole list per output.
    # The first record for a given index wins, as with the original [0].
    revrec_by_index = {}
    for record in revrecords:
        if "index" in record and "revrecords" in record:
            revrec_by_index.setdefault(record["index"], record["revrecords"])

    matches = []
    failed_indices = []
    for output in outputs:
        index = output.get("index")
        if index not in revrec_by_index:
            continue
        extracted_pairs = parse_revision_record(revrec_by_index[index])
        if not extracted_pairs:
            continue

        # Group repeated attribute keys: key -> list of recorded values.
        extracted_pairs_dict = {}
        for key, value in extracted_pairs:
            extracted_pairs_dict.setdefault(key, []).append(value)

        match_type = Evaluate(output["code"], extracted_pairs_dict)
        matches.append(match_type)
        if match_type in (MatchType.Fail, MatchType.Error):
            failed_indices.append(index)

    with open(failed_indices_file, "w", encoding="utf-8") as file:
        json.dump(failed_indices, file)

    total = len(matches)
    print(f"Total: {total}")
    if total == 0:
        # Nothing was evaluated; the rates below would divide by zero.
        return

    em = matches.count(MatchType.Execution)
    pm = matches.count(MatchType.PartialExecution)
    fm = matches.count(MatchType.Function)

    # Cumulative tiers: each rate also counts every better grade.
    execution_match = em / total
    partial_exec_match = (em + pm) / total
    function_match = (em + pm + fm) / total

    print(f"Execution Success: {execution_match}")
    print(f"Partial Execution Success: {partial_exec_match}")
    print(f"Function Match Success: {function_match}")


if __name__ == "__main__":
    # Default entry point: score the InstructExcel CF outputs.
    # To check local *.js / *.rr fixture pairs instead, call test(<directory>).
    test_instruct_excel()
